Adding a New Policy

Overview

This guide explains how to add a new policy to LW-BenchHub. All policies inherit from the BasePolicy class, which provides a standardized interface for policy deployment and evaluation.

Policy Framework Architecture

Base Policy Class

The BasePolicy class (policy/base.py) defines the common interface that all policies must implement:

class BasePolicy(ABC):
    """Base Policy Class - All policies should inherit from this class"""

    def __init__(self, usr_args: Dict[str, Any]):
        """Initialize policy with user arguments"""

    @abstractmethod
    def get_model(self, usr_args: Dict[str, Any]) -> Any:
        """Load and initialize the policy model"""

    @abstractmethod
    def get_action(self) -> Any:
        """Get action from the policy model"""

    @abstractmethod
    def eval(self, task_env, observation, usr_args, video_writer) -> bool:
        """Evaluate policy on a task"""

    @abstractmethod
    def reset_model(self) -> None:
        """Reset model state between episodes"""

Provided Utility Methods

The base class provides several utility methods that you can use:

  • encode_obs(observation): Preprocess observation data (handles tensor conversion and reshaping)
  • add_video_frame(video_writer, obs, camera_key): Add frames to video recording
  • step_environment(task_env, action, usr_args): Execute environment step with action mapping support
  • get_instruction(): Get task instruction from user arguments
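As a rough illustration, the sketch below shows how these utilities typically fit together inside a policy's eval() loop (illustrative only; concrete policies add their own pre- and post-processing around these calls):

# Sketch: chaining the base-class utilities inside eval()
def eval(self, task_env, observation, usr_args, video_writer) -> bool:
    instruction = self.get_instruction()    # task prompt from usr_args (typically fed to the model)
    for _ in range(usr_args['time_out_limit']):
        self.observation_window = self.encode_obs(observation)   # tensor conversion / reshaping
        action = self.get_action()                                # your model's prediction
        observation, terminated = self.step_environment(          # action mapping + env step
            task_env, action, usr_args
        )
        self.add_video_frame(video_writer, observation, usr_args['record_camera'])
        if terminated:
            return True
    return False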

Step-by-Step Guide: Adding a New Policy

Step 1: Create Policy Directory

Create a new directory under policy/ for your policy:

policy/
├── base.py
├── GR00T/
├── PI/
└── YourPolicy/            # Your new policy
    ├── your_policy.py
    └── deploy_policy.yml

Step 2: Implement Your Policy Class

Create your_policy.py that inherits from BasePolicy:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Your Policy Implementation
"""
import os
import sys
from typing import Dict, Any

import numpy as np
import torch

from policy.base import BasePolicy

# Add current directory to path if needed
current_file_path = os.path.abspath(__file__)
parent_directory = os.path.dirname(current_file_path)
sys.path.append(parent_directory)

# Import your policy library
try:
    from your_policy_library import YourPolicyModel
except ImportError as e:
    print(f"Your policy library not found: {e}")


class YourPolicy(BasePolicy):
    """Your Policy Implementation"""

    def __init__(self, usr_args: Dict[str, Any]):
        """Initialize your policy"""
        super().__init__(usr_args)

    def get_model(self, usr_args: Dict[str, Any]):
        """
        Load and initialize your policy model

        This method is called during initialization.
        Load your model checkpoint and set up any required configurations.
        """
        # Extract configuration from usr_args
        checkpoint = usr_args.get("checkpoint") or usr_args.get("ckpt_setting")
        observation_config = usr_args.get("observation_config", {})

        # Set up device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Store configurations
        self.observation_config = observation_config

        # Load your model
        self.model = YourPolicyModel(
            checkpoint=checkpoint,
            device=self.device
        )

        print("Successfully loaded your policy model!")

    def encode_obs(self, observation: Dict[str, Any]) -> Dict[str, Any]:
        """
        Encode observation into format expected by your model

        Override this method if you need custom observation preprocessing.
        """
        # Use parent's encode_obs for standard preprocessing
        observation = super().encode_obs(observation, transpose=True, keep_dim_env=False)

        # Build observation window for your model
        observation = self._build_observation_window(observation)

        return observation

    def _build_observation_window(self, obs: Dict[str, Any]) -> Dict[str, Any]:
        """
        Build observation window with custom mapping

        Map observation keys to your model's expected input format.
        """
        custom_mapping = self.observation_config.get("custom_mapping", {})
        obs_window = {"instruction": self.instruction}

        # Map observations according to custom_mapping
        for key, mapping in custom_mapping.items():
            if isinstance(mapping, dict):
                # Handle nested mapping
                obs_window[key] = {k: obs[v] for k, v in mapping.items()}
            else:
                # Direct mapping
                obs_window[key] = obs[mapping]

        return obs_window

    def get_action(self) -> torch.Tensor:
        """
        Get action from your policy model

        Returns:
            Action tensor to execute in the environment
        """
        # Get action from your model
        action = self.model.predict(self.observation_window)

        # Post-process if needed
        if isinstance(action, np.ndarray):
            action = torch.from_numpy(action).float().to(self.device)

        return action

    def eval(self, task_env: Any, observation: Dict[str, Any],
             usr_args: Dict[str, Any], video_writer: Any) -> bool:
        """
        Evaluate your policy on a task

        Args:
            task_env: Task environment (RemoteEnv instance)
            observation: Initial observation
            usr_args: User arguments (contains time_out_limit, record_camera, etc.)
            video_writer: Video writer for recording

        Returns:
            Whether the task was completed successfully
        """
        terminated = False

        # Main evaluation loop
        for step in range(usr_args['time_out_limit']):
            # Encode observation
            self.observation_window = self.encode_obs(observation)

            # Get action from policy
            actions = self.get_action()

            # Execute action(s)
            # If your policy outputs action chunks, iterate through them
            if actions.dim() > 1:  # Multiple actions (chunk)
                for i in range(actions.shape[0]):
                    observation, terminated = self.step_environment(
                        task_env, actions[i], usr_args
                    )
                    self.add_video_frame(
                        video_writer, observation, usr_args['record_camera']
                    )
                    if terminated:
                        return terminated
            else:  # Single action
                observation, terminated = self.step_environment(
                    task_env, actions, usr_args
                )
                self.add_video_frame(
                    video_writer, observation, usr_args['record_camera']
                )
                if terminated:
                    return terminated

        return terminated

    def reset_model(self) -> None:
        """
        Reset model state between episodes

        Clear any internal state or observation buffers.
        """
        self.observation_window = None
        # Reset any other model-specific state
        if hasattr(self.model, 'reset'):
            self.model.reset()

        print("Model state reset successfully")

Step 3: Create Configuration File

Create deploy_policy.yml to define the policy parameters. The configuration includes both policy settings and an environment configuration (env_cfg) section:

# Policy Configuration
policy_name: 'YourPolicy'              # Policy class name (must match your policy class)
seed: 0

# Model Configuration
ckpt_setting: 'path/to/checkpoint'     # Path to trained policy checkpoint
instruction: "Your task instruction"   # Task instruction/prompt

# Policy-specific parameters
your_custom_param1: value1
your_custom_param2: value2

# Observation Configuration
observation_config:
  custom_mapping:
    # Map observation keys from LW-BenchHub to your model's expected format
    images/front: global_camera        # Camera observations
    images/wrist: hand_camera
    state: joint_pos                   # State observations
    action: joint_target_pos

# Evaluation Configuration
record_camera: ["global_camera", "hand_camera"]
time_out_limit: 500
height: 480                            # Camera image height
width: 480                             # Camera image width

# Environment Configuration (sent to server via attach())
env_cfg:
  task: YourTask                       # Task name
  robot: LeRobot-AbsJointGripper-RL    # Robot type
  layout: robocasakitchen              # Scene layout
  scene_backend: robocasa              # Scene backend
  task_backend: robocasa               # Task backend
  device: cuda:0                       # Device for simulation
  num_envs: 1                          # Number of parallel environments
  enable_cameras: true                 # Enable camera observations
  usd_simplify: false                  # USD simplification
  video: false                         # Record video in environment
  seed: 42                             # Random seed
  for_rl: false                        # RL mode (false for policy evaluation)
  variant: Visual                      # Observation variant (Visual/State)
  concatenate_terms: false             # Concatenate observation terms
  distributed: false                   # Multi-GPU training mode
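For reference, the deployment entry point can load this file and split it into policy arguments and the environment configuration. This is a minimal sketch using PyYAML; the actual loader in LW-BenchHub may differ:

import yaml

# Hedged sketch: load deploy_policy.yml and separate the env_cfg section.
with open("policy/YourPolicy/deploy_policy.yml") as f:
    usr_args = yaml.safe_load(f)

env_cfg = usr_args.pop("env_cfg")   # sent to the simulation server via attach()
# usr_args now holds policy_name, ckpt_setting, observation_config, etc.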

Step 4: Register Your Policy

Add your policy to policy/__init__.py:

from policy.YourPolicy.your_policy import YourPolicy
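Once registered, the evaluation script can resolve the class from the policy_name field in the YAML. The getattr-based lookup below is only a sketch of one plausible mechanism; the actual resolution logic in LW-BenchHub may differ:

import policy

# Hedged sketch: resolve and construct the policy named in deploy_policy.yml.
policy_cls = getattr(policy, usr_args["policy_name"])   # e.g. "YourPolicy"
my_policy = policy_cls(usr_args)                        # get_model() runs during initialization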

Common Patterns and Best Practices

1. Observation Preprocessing

Different policies may require different observation formats:

def encode_obs(self, observation: Dict[str, Any]) -> Dict[str, Any]:
    # Option 1: Standard preprocessing (used by PI)
    observation = super().encode_obs(observation, transpose=True, keep_dim_env=False)

    # Option 2: Keep environment dimension (used by GR00T)
    observation = super().encode_obs(observation, transpose=False, keep_dim_env=True)

    # Then build your custom observation window
    observation = self._build_observation_window(observation)
    return observation

2. Action Chunking

If your policy predicts multiple future actions:

def eval(self, task_env, observation, usr_args, video_writer):
    terminated = False
    for _ in range(usr_args['time_out_limit']):
        self.observation_window = self.encode_obs(observation)
        actions = self.get_action()  # Shape: (chunk_size, action_dim)

        # Execute all actions in the chunk
        for i in range(actions.shape[0]):
            observation, terminated = self.step_environment(
                task_env, actions[i], usr_args
            )
            if terminated:
                return terminated
    return terminated

3. Joint Mapping

If your policy's action space differs from the robot's:

def step_environment(self, task_env, action, usr_args):
    # Apply joint mapping if provided
    if 'joint_mapping' in usr_args:
        action = action[usr_args['joint_mapping']]

    # Convert to tensor if needed
    if isinstance(action, np.ndarray):
        action = torch.from_numpy(action).float().cuda()

    obs, _, terminated, _, _ = task_env.step(action.unsqueeze(0))
    return obs, terminated

4. Error Handling

Always include proper import error handling:

try:
    from your_policy_library import YourModel
except ImportError as e:
    print(f"Policy library not found. Please install it first: {e}")
    print("Installation: pip install your-policy-library")

Troubleshooting

Common Issues

1. Import Errors

# Add proper path management
import sys
import os
current_file_path = os.path.abspath(__file__)
parent_directory = os.path.dirname(current_file_path)
sys.path.append(parent_directory)

2. Observation Shape Mismatch

# Debug observation shapes
def encode_obs(self, observation):
    print(f"Input obs keys: {observation['policy'].keys()}")
    for k, v in observation['policy'].items():
        print(f"{k}: {v.shape if torch.is_tensor(v) else type(v)}")

    observation = super().encode_obs(observation)
    # ... rest of encoding

3. Action Space Mismatch

# Verify action dimensions
def get_action(self):
    action = self.model.predict(self.observation_window)
    expected_dim = 7  # Your robot's action dimension
    assert action.shape[-1] == expected_dim, \
        f"Action dim mismatch: got {action.shape[-1]}, expected {expected_dim}"
    return action

4. Environment Not Attached

# Make sure to call attach() before using the environment
env = RemoteEnv.make(address=('127.0.0.1', 50000), authkey=b'lightwheel')
env.attach(env_cfg) # Don't forget this!
obs, _ = env.reset()

Summary

To add a new policy to LW-BenchHub:

  1. Create a new directory under policy/
  2. Implement your policy class inheriting from BasePolicy
  3. Override the four abstract methods: get_model, get_action, eval, reset_model
  4. Create a configuration YAML file with policy parameters and env_cfg section
  5. Register your policy in policy/__init__.py
  6. Test your implementation before deployment
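For step 6, a quick end-to-end check can combine the config-loading sketch from Step 3 with the attach pattern from the troubleshooting section. This is a hedged sketch: the RemoteEnv import path, server address, and authkey are examples, and passing video_writer=None assumes recording can be skipped:

from policy import YourPolicy
# RemoteEnv is provided by LW-BenchHub; adjust the import path to your installation.

# usr_args / env_cfg as loaded from deploy_policy.yml in the Step 3 sketch
env = RemoteEnv.make(address=('127.0.0.1', 50000), authkey=b'lightwheel')
env.attach(env_cfg)
obs, _ = env.reset()

my_policy = YourPolicy(usr_args)       # get_model() runs during initialization
my_policy.reset_model()
success = my_policy.eval(env, obs, usr_args, video_writer=None)  # None: no recording in this sketch
print(f"Task success: {success}")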

The framework provides flexibility through:

  • Standardized interfaces via BasePolicy
  • Utility methods for common operations
  • Flexible observation preprocessing
  • Action chunking support
  • Custom configuration management
  • Attach/Detach architecture for dynamic environment configuration

Refer to GR00T and PI implementations for real-world examples of different policy architectures and patterns.